import logging
from init import PATHS
LOGGER = logging.getLogger(__name__)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def main():
	"""Build the firm-year policy-exposure panel and write it to CSV.

	Country-year technology strategies (from 1995 on) are combined with
	country-year public R&D budgets. Each country's policy value is then
	weighted by the firm's 2004 share of sales in that country and summed over
	countries. The result has one row per firm-year, with one ``*_exposure`` /
	``*_RDexposure`` column per policy variable. It is written to
	``PATHS.dropbox / 'Data_outputted/D_Policy/firm_year_policy_exposure.csv'``.
	"""
	# Import policies and get data at country-year level
	strategies = pd.read_csv(PATHS.rawdata / 'PolicyLandscaping/country_tech_strategy_timeline_table.csv')
	strategies = transform_strategy_df_as_panel_country_year(strategies)
	strategies = strategies[strategies['Year'] >= 1995]
	# Import R&D data
	rdd = read_rdd()
	# Merge both policy sources into one country-year table
	policies = strategies.merge(rdd, on=['Country', 'Year'], how='outer')
	# Drop the alternative 'Germany/foeka' R&D series
	policies = policies[policies['Country'] != 'Germany/foeka']
	# Import firm sales data
	firm_country_shareofsales = get_sales_data()
	# Use the 2004 sales shares as the fixed reference point for every year
	firm_country_shareofsales = firm_country_shareofsales[firm_country_shareofsales['Year'] == 2004]
	firm_country_shareofsales = firm_country_shareofsales.drop(columns=['Year'])
	# One row per firm x country x policy year
	dffirmcountryyear = firm_country_shareofsales.merge(policies, on='Country', how='outer')
	exposure_cols = ['strat_fc_code', 'strat_bev_code', 'strat_focus_code', 'fchydrogen', 'otherstorage', 'otherstorageIEA', 'otherstoragesmooth', 'otherstoragesmooth2']
	dffirmyear = _firm_year_exposure(dffirmcountryyear, exposure_cols)
	dffirmyear.to_csv(PATHS.dropbox / 'Data_outputted/D_Policy/firm_year_policy_exposure.csv', index=False)


def _exposure_col_name(col):
	"""Return the exposure column name for policy column ``col``.

	Columns containing 'code' map ``<name>_code`` -> ``<name>_exposure``.
	All other columns (R&D budgets) map ``<name>`` -> ``<name>_RDexposure``.
	"""
	if 'code' in col:
		return col.split('_code')[0] + '_exposure'
	return col + '_RDexposure'


def _firm_year_exposure(dffirmcountryyear, exposure_cols):
	"""Aggregate firm x country x year policy values into sales-weighted firm-year exposures.

	For each column in ``exposure_cols``, the ``Share_of_Country_for_OEMSales``
	weights are rescaled so they sum to 1 over the countries where that column
	is non-missing for the firm-year. A country with missing policy data
	therefore does not pull the exposure toward 0. The exposure is the
	weighted sum of the policy values. A firm-year whose policy value is
	missing in every country gets NaN rather than 0.

	Parameters
	----------
	dffirmcountryyear : pandas.DataFrame
		Needs ``OEM_Level1_ID``, ``Level1_MLName``, ``Year``,
		``Share_of_Country_for_OEMSales`` and every column in ``exposure_cols``.
		Not modified.
	exposure_cols : list of str
		Policy columns to aggregate.

	Returns
	-------
	pandas.DataFrame
		``OEM_Level1_ID``, ``Level1_MLName``, ``Year`` plus one exposure
		column per entry of ``exposure_cols`` (named by ``_exposure_col_name``).
	"""
	keys = ['OEM_Level1_ID', 'Year']
	df = dffirmcountryyear
	new_cols = []
	for col in exposure_cols:
		# Total sales weight of the countries where this policy is observed
		totalweight = df[df[col].notnull()].groupby(keys)['Share_of_Country_for_OEMSales'].sum().rename('totalweight')
		df = df.merge(totalweight, on=keys, how='left')
		new_col = _exposure_col_name(col)
		new_cols.append(new_col)
		# Rescaled share x policy value; summing over countries below gives the weighted average
		df[new_col] = (df['Share_of_Country_for_OEMSales'] / df['totalweight']) * df[col]
		df = df.drop(columns=['totalweight'])
	# min_count=1 keeps all-missing firm-years as NaN instead of reporting zero exposure
	dffirmyear = df.groupby(['OEM_Level1_ID', 'Level1_MLName', 'Year'])[new_cols].sum(min_count=1)
	return dffirmyear.reset_index()


def rescale_col(group, col):
	"""Divide ``group[col]`` by the sales-share weight of the rows where ``col`` is present.

	The denominator is the sum of ``Share_of_Country_for_OEMSales`` over the
	rows whose ``col`` value is non-null. ``group`` is modified in place and
	also returned. The parameter name suggests use with ``groupby(...).apply``,
	but no caller is visible in this module.
	"""
	present = group[col].notna()
	denominator = group['Share_of_Country_for_OEMSales'][present].sum()
	group[col] = group[col] / denominator
	return group


def read_rdd():
	"""Load public energy R&D budgets as a country-year panel with one column per technology.

	Reads ``RDD1995_2020.csv``, which has columns ``year``, ``technology`` and
	one column per country/source. Returns ``Country``, ``Year`` and:

	* one column per technology, with ``HGENCELL`` renamed to ``fchydrogen``
	  and ``OTHERPANDS`` to ``otherstorage``; the USA values come from the
	  ``USA/DOE`` column;
	* ``otherstorageIEA``: ``OTHERPANDS`` with the USA values taken from the
	  ``USA/IEA`` column instead (usable as a robustness check);
	* ``otherstoragesmooth`` / ``otherstoragesmooth2``: ``otherstorage`` with
	  the USA 2009 excess spread over 2009-2011 (see ``_smooth_usa_2009_spike``).

	Country names other than ``USA`` and ``UK`` are passed through
	``str.capitalize``.
	"""
	rdd_0 = pd.read_csv(PATHS.dropbox / 'Data_outputted/D_Policy/RDD1995_2020.csv')
	tech_names = {'HGENCELL': 'fchydrogen', 'OTHERPANDS': 'otherstorage'}
	# Main panel: USA figures come from the USA/DOE column
	rdd = rdd_0.rename(columns={'USA/DOE': 'USA'}).drop(columns=['USA/IEA'])
	rdd['technology'] = rdd['technology'].replace(tech_names)
	rdd = _rdd_wide_to_panel(rdd)
	# Alternative USA source for storage only: the USA/IEA column, kept as otherstorageIEA.
	# .copy() so the technology rename below writes to a real frame, not a slice.
	rddiea = rdd_0.rename(columns={'USA/IEA': 'USA'}).drop(columns=['USA/DOE'])
	rddiea = rddiea.loc[rddiea['technology'] == 'OTHERPANDS'].copy()
	rddiea['technology'] = rddiea['technology'].replace(tech_names)
	rddiea = _rdd_wide_to_panel(rddiea).rename(columns={'otherstorage': 'otherstorageIEA'})
	rdd = rdd.merge(rddiea, on=['Country', 'Year'], how='outer')
	rdd['Country'] = rdd['Country'].apply(lambda x: x.capitalize() if x not in ['USA', 'UK'] else x)
	return _smooth_usa_2009_spike(rdd)


def _rdd_wide_to_panel(rdd_wide):
	"""Reshape rows of (year, technology, one column per country) into a Country/Year x technology frame."""
	rdd_long = pd.melt(rdd_wide, id_vars=['year', 'technology'], var_name='country', value_name='rdd_public_spending')
	rdd_long = rdd_long.rename(columns={'country': 'Country', 'year': 'Year'})
	return rdd_long.pivot(index=['Country', 'Year'], columns='technology', values='rdd_public_spending').reset_index()


def _smooth_usa_2009_spike(rdd):
	"""Add ``otherstoragesmooth`` / ``otherstoragesmooth2``: USA ``otherstorage`` with the 2009 excess spread out.

	The counterfactual 2009 level is the midpoint of 2008 and 2010. The excess
	``ARA = R2009 - baseline`` (presumably the one-off 2009 stimulus money;
	TODO confirm) is redistributed over 2009, 2010 and 2011:

	* ``otherstoragesmooth``: 50% / 30% / 20%;
	* ``otherstoragesmooth2``: one third each.

	Both versions keep the USA 2009-2011 total unchanged. All other rows copy
	``otherstorage``. Raises IndexError if a USA row for 2008-2011 is missing.
	"""
	is_usa = rdd['Country'] == 'USA'

	def usa_value(year):
		return rdd.loc[is_usa & (rdd['Year'] == year), 'otherstorage'].iloc[0]

	R2008, R2009, R2010, R2011 = (usa_value(year) for year in (2008, 2009, 2010, 2011))
	# Linear interpolation between 2008 and 2010. The previous code used
	# (R2010 - R2008) / 2, a half-difference rather than the midpoint.
	baseline_2009 = (R2008 + R2010) / 2
	ARA = R2009 - baseline_2009
	rows = {year: is_usa & (rdd['Year'] == year) for year in (2009, 2010, 2011)}
	# Split the ARA amount over the next 3 years: 50/30/20
	rdd['otherstoragesmooth'] = rdd['otherstorage']
	rdd.loc[rows[2009], 'otherstoragesmooth'] = baseline_2009 + 0.5 * ARA
	rdd.loc[rows[2010], 'otherstoragesmooth'] = R2010 + 0.3 * ARA
	rdd.loc[rows[2011], 'otherstoragesmooth'] = R2011 + 0.2 * ARA
	# Alternative: the same amount in each of the 3 years
	rdd['otherstoragesmooth2'] = rdd['otherstorage']
	rdd.loc[rows[2009], 'otherstoragesmooth2'] = baseline_2009 + ARA / 3
	rdd.loc[rows[2010], 'otherstoragesmooth2'] = R2010 + ARA / 3
	rdd.loc[rows[2011], 'otherstoragesmooth2'] = R2011 + ARA / 3
	return rdd


def transform_strategy_df_as_panel_country_year(strategies):
	"""Expand strategy timeline rows into a balanced country-year panel of strategy codes.

	Each input row is active from ``start.date`` through
	``end.date.nooverlap`` inclusive. Every year in that range receives the
	row's ``strat_fc_code``, ``strat_bev_code`` and ``strat_focus_code``. The
	panel is then balanced over every country x every year that appears in
	any row, and uncovered country-years are coded 0.

	Parameters
	----------
	strategies : pandas.DataFrame
		Columns ``country``, ``start.date``, ``end.date.nooverlap`` and the
		three strategy code columns. Any index is accepted.

	Returns
	-------
	pandas.DataFrame
		Columns ``Country``, ``Year``, ``strat_fc_code``, ``strat_bev_code``,
		``strat_focus_code``.
	"""
	code_cols = ['strat_fc_code', 'strat_bev_code', 'strat_focus_code']
	policy_df = []
	# iterrows pairs each row with its label, so this works for any index.
	# The old positional .iloc[label] lookup broke when the index was not 0..n-1.
	for _, row in strategies.iterrows():
		d = pd.DataFrame({'year': np.arange(row['start.date'], row['end.date.nooverlap'] + 1)})
		d['country'] = row['country']
		for code_col in code_cols:
			d[code_col] = row[code_col]
		policy_df.append(d)
	if not policy_df:
		# pd.concat([]) would raise; return an empty panel with the expected columns
		return pd.DataFrame(columns=['Country', 'Year'] + code_cols)
	policy_df = pd.concat(policy_df).set_index(['country', 'year'])
	## fill in with 0 for years that aren't coded: balance over all countries x all years seen
	full_index = pd.MultiIndex.from_product(policy_df.index.levels, names=policy_df.index.names)
	policy_df = policy_df.reindex(full_index).reset_index()
	policy_df = policy_df.fillna(0)
	return policy_df.rename(columns={'country': 'Country', 'year': 'Year'})


def get_sales_data():
	"""Load firm sales by country and attach two sales-share measures.

	Reads ``sales_firm_year_country_level.csv`` and keeps rows for Germany,
	USA, China, France, UK, Japan and Korea with a positive ``value``. The
	``value`` column becomes ``OEMCountrySales``, and two shares are added:

	* ``Share_of_Country_for_OEMSales``: ``OEMCountrySales`` over the firm's
	  summed ``value`` across the kept rows in that year;
	* ``Share_of_OEM_for_CountrySales``: ``OEMCountrySales`` over all firms'
	  summed ``value`` in that country-year.

	Returns the distinct rows of ``Year``, ``Level1_MLName``,
	``OEM_Level1_ID``, ``Country`` and the two shares.
	"""
	kept_countries = ['Germany', 'USA', 'China', 'France', 'UK', 'Japan', 'Korea']
	firm_keys = ['Year', 'Level1_MLName', 'OEM_Level1_ID']
	country_keys = ['Year', 'Country']

	raw = pd.read_csv(PATHS.dropbox / 'Data_outputted/A_AutoIndustry/sales_firm_year_country_level.csv')
	sales = raw[raw['Country'].isin(kept_countries)]
	sales = sales[(sales['value'] > 0) & sales['value'].notnull()]

	# Denominators: per firm-year total and per country-year total
	oem_totals = sales.groupby(firm_keys)['value'].sum().rename('OEMGlobalSales').reset_index()
	country_totals = sales.groupby(country_keys)['value'].sum().rename('TotalCountrySales').reset_index()

	merged = (
		sales
		.merge(oem_totals, on=firm_keys, how='inner')
		.merge(country_totals, on=country_keys, how='inner')
		.rename(columns={'value': 'OEMCountrySales'})
	)
	merged['Share_of_Country_for_OEMSales'] = merged['OEMCountrySales'] / merged['OEMGlobalSales']
	merged['Share_of_OEM_for_CountrySales'] = merged['OEMCountrySales'] / merged['TotalCountrySales']

	out_cols = firm_keys + ['Country', 'Share_of_Country_for_OEMSales', 'Share_of_OEM_for_CountrySales']
	return merged[out_cols].drop_duplicates()



def plot_rdd_series_US_DE():
	"""Save plots comparing R&D series from different data sources for the USA and Germany.

	Writes two PNGs under ``PATHS.dropbox / 'Data_outputted/D_Policy'``:

	* ``USA_time_series_by_data_source.png``: the USA/DOE and USA/IEA series
	  for ``fchydrogen`` (raw ``HGENCELL``, left) and ``otherstorage``
	  (raw ``OTHERPANDS``, right);
	* ``Germany_time_series_by_data_source.png``: the ``Germany`` (IEA) and
	  ``Germany/foeka`` series for ``fchydrogen``.
	"""
	rdd = read_rdd()
	# read_rdd() maps USA/DOE to Country 'USA' and keeps USA/IEA only for OTHERPANDS
	# (as the otherstorageIEA column); no 'Usa/iea' country ever exists.
	# The IEA fuel-cell series is therefore read from the raw file.
	usa = rdd[rdd['Country'] == 'USA']
	raw = pd.read_csv(PATHS.dropbox / 'Data_outputted/D_Policy/RDD1995_2020.csv')
	usa_iea_fc = raw[raw['technology'] == 'HGENCELL'].sort_values('year')

	fig1, (ax1, ax2) = plt.subplots(1, 2)
	ax1.plot(usa['Year'], usa['fchydrogen'], label='US/DOE')  # this is the USA/DOE, i.e. Anadon data
	ax1.plot(usa_iea_fc['year'], usa_iea_fc['USA/IEA'], label='US/IEA')
	ax1.set_title('FCHydrogen USA')
	ax1.legend()
	ax2.plot(usa['Year'], usa['otherstorage'], label='US/DOE')
	ax2.plot(usa['Year'], usa['otherstorageIEA'], label='US/IEA')
	ax2.set_title('Other storage USA')
	ax2.legend()
	fig1.savefig(PATHS.dropbox / 'Data_outputted/D_Policy/USA_time_series_by_data_source.png')
	plt.close(fig1)

	germany = rdd[rdd['Country'] == 'Germany']
	foeka = rdd[rdd['Country'] == 'Germany/foeka']
	fig2, ax = plt.subplots()
	ax.plot(germany['Year'], germany['fchydrogen'], label='IEA')
	ax.plot(foeka['Year'], foeka['fchydrogen'], label='Foeka')
	ax.legend()
	ax.set_title('Fuel Cell Germany')
	fig2.savefig(PATHS.dropbox / 'Data_outputted/D_Policy/Germany_time_series_by_data_source.png')
	plt.close(fig2)


